// RUN: litert-opt -tfl-prepare-composite-funcs-tf %s -split-input-file -verify-diagnostics | FILECHECK_OPTS="" FileCheck %s

module{
// Composite function tagged tf._implements = "embedding_matmul".
// The body expands the indices in %arg1 along axis 1, compares them against
// the range [0, 4096) with step 1, casts the resulting equality mask to f32
// and multiplies it with %arg0. The FileCheck assertions further down expect
// the pass to replace this body with a single "tfl.embedding_lookup" op.
func.func @embedding(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> attributes  {tf._implements = "embedding_matmul", tf._reference = "mlir"} {
  %axis = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %ids = "tf.ExpandDims"(%arg1, %axis) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
  %step = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %limit = "tf.Const"() {value = dense<4096> : tensor<i32>} : () -> tensor<i32>
  %start = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %positions = "tf.Range"(%start, %limit, %step) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<4096xi32>
  %mask = "tf.Equal"(%ids, %positions) : (tensor<*xi32>, tensor<4096xi32>) -> tensor<*xi1>
  %mask_f32 = "tf.Cast"(%mask) : (tensor<*xi1>) -> tensor<*xf32>
  %looked_up = "tf.BatchMatMulV2"(%mask_f32, %arg0) {adj_x = false, adj_y = false} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  func.return %looked_up : tensor<*xf32>
}

// Composite function tagged tf._implements = "LSTMCellSimple" whose arguments
// are all ranked tensors. The body multiplies %arg3 by %arg1, adds a constant
// row and casts the 1x4 result to tensor<1x?xf32>. The FileCheck assertions
// further down expect the pass to rewrite the function around a fused
// "tfl.lstm" op.
func.func @lstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<2xf32>, %arg3: tensor<1x3xf32>, %arg4: tensor<?xf32>) -> tensor<1x?xf32> attributes  {tf._implements = "LSTMCellSimple", tf._reference = "mlir"} {
    %product = "tf.BatchMatMulV2"(%arg3, %arg1) {adj_x = false, adj_y = false} : (tensor<1x3xf32>, tensor<3x4xf32>) -> tensor<1x4xf32>
    %offset = arith.constant dense<[[2.3, 3.4, 4.5, 5.5]]> : tensor<1x4xf32>
    %shifted = "tf.Add"(%product, %offset) : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xf32>
    %result = tensor.cast %shifted : tensor<1x4xf32> to tensor<1x?xf32>
    func.return %result : tensor<1x?xf32>
}

// Composite function tagged tf._implements = "LayerNormalizedLstmCellSimple"
// with ranked arguments; unlike @lstmcellsimple above, %arg4 is a static
// tensor<2xf32>. The body is the same matmul + constant add + cast sequence.
// The FileCheck assertions further down expect a fused "tfl.lstm" whose last
// four operands are tensor<1xf32> values rather than no-value placeholders.
func.func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<2xf32>, %arg3: tensor<1x3xf32>, %arg4: tensor<2xf32>) -> tensor<1x?xf32> attributes  {tf._implements = "LayerNormalizedLstmCellSimple", tf._reference = "mlir"} {
    %product = "tf.BatchMatMulV2"(%arg3, %arg1) {adj_x = false, adj_y = false} : (tensor<1x3xf32>, tensor<3x4xf32>) -> tensor<1x4xf32>
    %offset = arith.constant dense<[[2.3, 3.4, 4.5, 5.5]]> : tensor<1x4xf32>
    %shifted = "tf.Add"(%product, %offset) : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xf32>
    %result = tensor.cast %shifted : tensor<1x4xf32> to tensor<1x?xf32>
    func.return %result : tensor<1x?xf32>
}

// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK-LABEL:   func @embedding(
// CHECK-SAME:              [[VAL_0:%.*]]: tensor<*xf32>, [[VAL_1:%.*]]: tensor<*xi32>) -> tensor<*xf32>

// CHECK-LABEL:   attributes  {tf._implements = "embedding_lookup", tf._reference = "mlir"} {
// CHECK:           [[VAL_2:%.*]] = "tfl.embedding_lookup"([[VAL_1]], [[VAL_0]]) : (tensor<*xi32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK:           return [[VAL_2]] : tensor<*xf32>

// CHECK-LABEL:   func @lstmcellsimple(
// CHECK-SAME:                          [[VAL_0]]: tensor<1x?xf32>, [[VAL_1]]: tensor<3x4xf32>, [[VAL_3:%.*]]: tensor<2xf32>, [[VAL_4:%.*]]: tensor<1x3xf32>, [[VAL_5:%.*]]: tensor<?xf32>) -> tensor<1x?xf32>

// CHECK-LABEL:   attributes  {tf._implements = "LSTMCellSimple", tf._reference = "mlir"} {
// CHECK:           [[VAL_6:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_1]], [[VAL_6]]) : (tensor<3x4xf32>, tensor<2xi32>) -> tensor<4x3xf32>
// CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<1x3xf32>, tensor<2xi32>) -> tensor<3x1xf32>
// CHECK-DAG:       [[VAL_10:%.*]] = "tfl.no_value"() <{value}> : () -> none
// CHECK-DAG:       [[VAL_11:%.*]] = arith.constant dense<0> : tensor<2xi64>
// CHECK-DAG:       [[VAL_12:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi64>
// CHECK:           [[VAL_13:%.*]] = "tf.Slice"([[VAL_7]], [[VAL_11]], [[VAL_12]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x0xf32>
// CHECK-DAG:       [[VAL_14:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi64>
// CHECK-DAG:       [[VAL_15:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi64>
// CHECK:           [[VAL_16:%.*]] = "tf.Slice"([[VAL_7]], [[VAL_14]], [[VAL_15]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x0xf32>
// CHECK-DAG:       [[VAL_17:%.*]] = arith.constant dense<[2, 0]> : tensor<2xi64>
// CHECK-DAG:       [[VAL_18:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi64>
// CHECK:           [[VAL_19:%.*]] = "tf.Slice"([[VAL_7]], [[VAL_17]], [[VAL_18]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x0xf32>
// CHECK-DAG:       [[VAL_20:%.*]] = arith.constant dense<[3, 0]> : tensor<2xi64>
// CHECK-DAG:       [[VAL_21:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi64>
// CHECK:           [[VAL_22:%.*]] = "tf.Slice"([[VAL_7]], [[VAL_20]], [[VAL_21]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x0xf32>
// CHECK-DAG:       [[VAL_23:%.*]] = arith.constant dense<0> : tensor<2xi64>
// CHECK-DAG:       [[VAL_24:%.*]] = arith.constant dense<[1, 3]> : tensor<2xi64>
// CHECK:           [[VAL_25:%.*]] = "tf.Slice"([[VAL_7]], [[VAL_23]], [[VAL_24]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xf32>
// CHECK-DAG:       [[VAL_26:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi64>
// CHECK-DAG:       [[VAL_27:%.*]] = arith.constant dense<[1, 3]> : tensor<2xi64>
// CHECK:           [[VAL_28:%.*]] = "tf.Slice"([[VAL_7]], [[VAL_26]], [[VAL_27]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xf32>
// CHECK-DAG:       [[VAL_29:%.*]] = arith.constant dense<[2, 0]> : tensor<2xi64>
// CHECK-DAG:       [[VAL_30:%.*]] = arith.constant dense<[1, 3]> : tensor<2xi64>
// CHECK:           [[VAL_31:%.*]] = "tf.Slice"([[VAL_7]], [[VAL_29]], [[VAL_30]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xf32>
// CHECK-DAG:       [[VAL_32:%.*]] = arith.constant dense<[3, 0]> : tensor<2xi64>
// CHECK-DAG:       [[VAL_33:%.*]] = arith.constant dense<[1, 3]> : tensor<2xi64>
// CHECK:           [[VAL_34:%.*]] = "tf.Slice"([[VAL_7]], [[VAL_32]], [[VAL_33]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xf32>
// CHECK-DAG:       [[VAL_35:%.*]] = arith.constant dense<0> : tensor<1xi64>
// CHECK-DAG:       [[VAL_36:%.*]] = arith.constant dense<1> : tensor<1xi64>
// CHECK:           [[VAL_37:%.*]] = "tf.Slice"([[VAL_3]], [[VAL_35]], [[VAL_36]]) : (tensor<2xf32>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xf32>
// CHECK-DAG:       [[VAL_38:%.*]] = arith.constant dense<1> : tensor<1xi64>
// CHECK-DAG:       [[VAL_39:%.*]] = arith.constant dense<1> : tensor<1xi64>
// CHECK:           [[VAL_40:%.*]] = "tf.Slice"([[VAL_3]], [[VAL_38]], [[VAL_39]]) : (tensor<2xf32>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xf32>
// CHECK-DAG:       [[VAL_41:%.*]] = arith.constant dense<0.000000e+00> : tensor<1xf32>
// CHECK-DAG:       [[VAL_42:%.*]] = arith.constant dense<0.000000e+00> : tensor<1xf32>
// CHECK-DAG:       [[VAL_43:%.*]] = arith.constant dense<0> : tensor<2xi64>
// CHECK-DAG:       [[VAL_44:%.*]] = arith.constant dense<[3, 1]> : tensor<2xi64>
// CHECK:           [[VAL_45:%.*]] = "tf.Slice"([[VAL_9]], [[VAL_43]], [[VAL_44]]) : (tensor<3x1xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x1xf32>
// CHECK-DAG:       [[VAL_46:%.*]] = arith.constant dense<0.000000e+00> : tensor<3xf32>
// CHECK-DAG:       [[VAL_47:%.*]] = arith.constant dense<0.000000e+00> : tensor<1x3xf32>
// CHECK-DAG:       [[VAL_48:%.*]] = arith.constant dense<0.000000e+00> : tensor<1x1xf32>
// CHECK:           [[VAL_49:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_16]], [[VAL_19]], [[VAL_13]], [[VAL_22]], [[VAL_28]], [[VAL_31]], [[VAL_25]], [[VAL_34]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_40]], [[VAL_41]], [[VAL_37]], [[VAL_42]], [[VAL_45]], [[VAL_46]], [[VAL_47]], [[VAL_48]], [[VAL_10]], [[VAL_10]], [[VAL_10]], [[VAL_10]]) 
// CHECK-SAME:      <{cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = #tfl<lstm_kernel_type_attr FULL>, proj_clip = 0.000000e+00 : f32}> ({
// CHECK:           }) : (tensor<1x?xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<3x1xf32>, tensor<3xf32>, tensor<1x3xf32>, tensor<1x1xf32>, none, none, none, none) -> tensor<1x3xf32>
// CHECK:           [[VAL_50:%.*]] = tensor.cast [[VAL_51:%.*]] : tensor<1x3xf32> to tensor<1x?xf32>
// CHECK:           return [[VAL_50]] : tensor<1x?xf32>

// CHECK-LABEL:   func @layernormalizedlstmcellsimple(
// CHECK-SAME:                                        [[VAL_0]]: tensor<1x?xf32>, [[VAL_1]]: tensor<3x4xf32>, [[VAL_3]]: tensor<2xf32>, [[VAL_4]]: tensor<1x3xf32>, [[VAL_5]]: tensor<2xf32>) -> tensor<1x?xf32>

// CHECK-LABEL:   attributes  {tf._implements = "LayerNormalizedLstmCellSimple", tf._reference = "mlir"} {
// CHECK:           [[VAL_52:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_53:%.*]] = "tf.Transpose"([[VAL_1]], [[VAL_52]]) : (tensor<3x4xf32>, tensor<2xi32>) -> tensor<4x3xf32>
// CHECK:           [[VAL_54:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_55:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_54]]) : (tensor<1x3xf32>, tensor<2xi32>) -> tensor<3x1xf32>
// CHECK-DAG:       [[VAL_56:%.*]] = "tfl.no_value"() <{value}> : () -> none
// CHECK-DAG:       [[VAL_57:%.*]] = arith.constant dense<0> : tensor<2xi64>
// CHECK-DAG:       [[VAL_58:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi64>
// CHECK:           [[VAL_59:%.*]] = "tf.Slice"([[VAL_53]], [[VAL_57]], [[VAL_58]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x0xf32>
// CHECK-DAG:       [[VAL_60:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi64>
// CHECK-DAG:       [[VAL_61:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi64>
// CHECK:           [[VAL_62:%.*]] = "tf.Slice"([[VAL_53]], [[VAL_60]], [[VAL_61]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x0xf32>
// CHECK-DAG:       [[VAL_63:%.*]] = arith.constant dense<[2, 0]> : tensor<2xi64>
// CHECK-DAG:       [[VAL_64:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi64>
// CHECK:           [[VAL_65:%.*]] = "tf.Slice"([[VAL_53]], [[VAL_63]], [[VAL_64]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x0xf32>
// CHECK-DAG:       [[VAL_66:%.*]] = arith.constant dense<[3, 0]> : tensor<2xi64>
// CHECK-DAG:       [[VAL_67:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi64>
// CHECK:           [[VAL_68:%.*]] = "tf.Slice"([[VAL_53]], [[VAL_66]], [[VAL_67]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x0xf32>
// CHECK-DAG:       [[VAL_69:%.*]] = arith.constant dense<0> : tensor<2xi64>
// CHECK-DAG:       [[VAL_70:%.*]] = arith.constant dense<[1, 3]> : tensor<2xi64>
// CHECK:           [[VAL_71:%.*]] = "tf.Slice"([[VAL_53]], [[VAL_69]], [[VAL_70]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xf32>
// CHECK-DAG:       [[VAL_72:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi64>
// CHECK-DAG:       [[VAL_73:%.*]] = arith.constant dense<[1, 3]> : tensor<2xi64>
// CHECK:           [[VAL_74:%.*]] = "tf.Slice"([[VAL_53]], [[VAL_72]], [[VAL_73]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xf32>
// CHECK-DAG:       [[VAL_75:%.*]] = arith.constant dense<[2, 0]> : tensor<2xi64>
// CHECK-DAG:       [[VAL_76:%.*]] = arith.constant dense<[1, 3]> : tensor<2xi64>
// CHECK:           [[VAL_77:%.*]] = "tf.Slice"([[VAL_53]], [[VAL_75]], [[VAL_76]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xf32>
// CHECK-DAG:       [[VAL_78:%.*]] = arith.constant dense<[3, 0]> : tensor<2xi64>
// CHECK-DAG:       [[VAL_79:%.*]] = arith.constant dense<[1, 3]> : tensor<2xi64>
// CHECK:           [[VAL_80:%.*]] = "tf.Slice"([[VAL_53]], [[VAL_78]], [[VAL_79]]) : (tensor<4x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x3xf32>
// CHECK-DAG:       [[VAL_81:%.*]] = arith.constant dense<0> : tensor<1xi64>
// CHECK-DAG:       [[VAL_82:%.*]] = arith.constant dense<1> : tensor<1xi64>
// CHECK:           [[VAL_83:%.*]] = "tf.Slice"([[VAL_3]], [[VAL_81]], [[VAL_82]]) : (tensor<2xf32>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xf32>
// CHECK-DAG:       [[VAL_84:%.*]] = arith.constant dense<1> : tensor<1xi64>
// CHECK-DAG:       [[VAL_85:%.*]] = arith.constant dense<1> : tensor<1xi64>
// CHECK:           [[VAL_86:%.*]] = "tf.Slice"([[VAL_3]], [[VAL_84]], [[VAL_85]]) : (tensor<2xf32>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xf32>
// CHECK-DAG:       [[VAL_87:%.*]] = arith.constant dense<0.000000e+00> : tensor<1xf32>
// CHECK-DAG:       [[VAL_88:%.*]] = arith.constant dense<0.000000e+00> : tensor<1xf32>
// CHECK-DAG:       [[VAL_89:%.*]] = arith.constant dense<0> : tensor<2xi64>
// CHECK-DAG:       [[VAL_90:%.*]] = arith.constant dense<[3, 1]> : tensor<2xi64>
// CHECK:           [[VAL_91:%.*]] = "tf.Slice"([[VAL_55]], [[VAL_89]], [[VAL_90]]) : (tensor<3x1xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x1xf32>
// CHECK-DAG:       [[VAL_92:%.*]] = arith.constant dense<0.000000e+00> : tensor<3xf32>
// CHECK-DAG:       [[VAL_93:%.*]] = arith.constant dense<0.000000e+00> : tensor<1x3xf32>
// CHECK-DAG:       [[VAL_94:%.*]] = arith.constant dense<0.000000e+00> : tensor<1x1xf32>
// CHECK-DAG:       [[VAL_95:%.*]] = arith.constant dense<0> : tensor<1xi64>
// CHECK-DAG:       [[VAL_96:%.*]] = arith.constant dense<1> : tensor<1xi64>
// CHECK:           [[VAL_97:%.*]] = "tf.Slice"([[VAL_5]], [[VAL_95]], [[VAL_96]]) : (tensor<2xf32>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xf32>
// CHECK-DAG:       [[VAL_98:%.*]] = arith.constant dense<1> : tensor<1xi64>
// CHECK-DAG:       [[VAL_99:%.*]] = arith.constant dense<1> : tensor<1xi64>
// CHECK:           [[VAL_100:%.*]] = "tf.Slice"([[VAL_5]], [[VAL_98]], [[VAL_99]]) : (tensor<2xf32>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xf32>
// CHECK-DAG:       [[VAL_101:%.*]] = arith.constant dense<0.000000e+00> : tensor<1xf32>
// CHECK-DAG:       [[VAL_102:%.*]] = arith.constant dense<0.000000e+00> : tensor<1xf32>
// CHECK:           [[VAL_103:%.*]] = "tfl.lstm"([[VAL_0]], [[VAL_62]], [[VAL_65]], [[VAL_59]], [[VAL_68]], [[VAL_74]], [[VAL_77]], [[VAL_71]], [[VAL_80]], [[VAL_56]], [[VAL_56]], [[VAL_56]], [[VAL_86]], [[VAL_87]], [[VAL_83]], [[VAL_88]], [[VAL_91]], [[VAL_92]], [[VAL_93]], [[VAL_94]], [[VAL_100]], [[VAL_101]], [[VAL_97]], [[VAL_102]])
// CHECK-SAME:      <{cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", kernel_type = #tfl<lstm_kernel_type_attr FULL>, proj_clip = 0.000000e+00 : f32}> ({
// CHECK:           }) : (tensor<1x?xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x0xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>, none, none, none, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<3x1xf32>, tensor<3xf32>, tensor<1x3xf32>, tensor<1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x3xf32>
// CHECK:           [[VAL_104:%.*]] = tensor.cast [[VAL_105:%.*]] : tensor<1x3xf32> to tensor<1x?xf32>
// CHECK:           return [[VAL_104]] : tensor<1x?xf32>
}

// -----

module{

// expected-warning @+1 {{we cannot fuse this lstm func because all the inputs have not ranked tensor type.}}
func.func @lstmcellsimple(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>, %arg3: tensor<*xf32>, %arg4: tensor<*xf32>) -> tensor<*xf32> attributes  {tf._implements = "LSTMCellSimple", tf._reference = "mlir"} {
    // Every argument and the result are unranked (tensor<*xf32>). This chunk
    // only verifies the warning annotated on the line directly above the
    // func; that annotation is line-relative, so no line may be inserted
    // between it and the func header.
    %product = "tf.BatchMatMulV2"(%arg3, %arg1) {adj_x = false, adj_y = false} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
    %offset = arith.constant dense<[[2.3, 3.4, 4.5, 5.5]]> : tensor<1x4xf32>
    %shifted = "tf.Add"(%product, %offset) : (tensor<*xf32>, tensor<1x4xf32>) -> tensor<*xf32>
    %result = tensor.cast %shifted : tensor<*xf32> to tensor<*xf32>
    func.return %result : tensor<*xf32>
}

// expected-warning @+1 {{we cannot fuse this lstm func because all the inputs have not ranked tensor type.}}
func.func @layernormalizedlstmcellsimple(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>, %arg3: tensor<*xf32>, %arg4: tensor<*xf32>) -> tensor<*xf32> attributes  {tf._implements = "LayerNormalizedLstmCellSimple", tf._reference = "mlir"} {
    // Unranked variant of the layer-normalized composite: all arguments and
    // the result are tensor<*xf32>. This chunk only verifies the warning
    // annotated on the line directly above the func; that annotation is
    // line-relative, so it must stay immediately before the func header.
    %product = "tf.BatchMatMulV2"(%arg3, %arg1) {adj_x = false, adj_y = false} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
    %offset = arith.constant dense<[[2.3, 3.4, 4.5, 5.5]]> : tensor<1x4xf32>
    %shifted = "tf.Add"(%product, %offset) : (tensor<*xf32>, tensor<1x4xf32>) -> tensor<*xf32>
    %result = tensor.cast %shifted : tensor<*xf32> to tensor<*xf32>
    func.return %result : tensor<*xf32>
}

}

// -----

module {
// Function tagged tf.api_implements = "lstm_<uuid>" with tf.time_major = true
// and a dynamic leading dimension on %arg0. The FileCheck assertions that
// follow expect the pass to rewrite it around a
// "tfl.unidirectional_sequence_lstm" op with time_major = true and
// diagonal_recurrent_tensors = false.
// Note: %plus_arg1 below is computed but never used, as in the original input.
func.func @inference_standard_lstm_time_major(%arg0: tensor<?x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
  %input_proj = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
  %biased = "tf.Add"(%input_proj, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
  %recurrent_proj = "tf.BatchMatMulV2"(%biased, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
  %plus_arg1 = "tf.Add"(%recurrent_proj, %arg1) : (tensor<?x8x10xf32>, tensor<8x10xf32>) -> tensor<?x8x10xf32>
  %plus_arg2 = "tf.Add"(%recurrent_proj, %arg2) : (tensor<?x8x10xf32>, tensor<8x10xf32>) -> tensor<?x8x10xf32>
  %state_sum = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32>
  %unit = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  func.return %state_sum, %plus_arg2, %state_sum, %state_sum, %unit : tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
}

// CHECK:       func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
// CHECK:           [[VAL_6:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
// CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32>
// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_17:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK:           [[VAL_19:%.*]] = "tfl.no_value"() <{value}> : () -> none
// CHECK:           [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true}> : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
// CHECK-DAG:       [[VAL_21:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32>
// CHECK-DAG:       [[VAL_22:%.*]] = arith.constant dense<0> : tensor<3xi32>
// CHECK-DAG:       [[VAL_23:%.*]] = arith.constant dense<1> : tensor<3xi32>
// CHECK:           [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) <{begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<?x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_25:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
// CHECK:           return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
// CHECK:         }
}

// -----

module {
// Function tagged tf.api_implements = "indy_lstm_<uuid>" with
// tf.time_major = true. Here %arg4 is tensor<10x4xf32>; the body transposes
// it, expands it with tf.MatrixDiag and concatenates the result into a 40x10
// matrix before the second matmul. The FileCheck assertions that follow
// expect a "tfl.unidirectional_sequence_lstm" op with
// diagonal_recurrent_tensors = true, fed by four tensor<10xf32> values
// reshaped from a split of the transposed %arg4.
// Note: %plus_arg1 below is computed but never used, as in the original input.
func.func @inference_standard_indy_lstm_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x4xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "indy_lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
  %input_proj = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
  %biased = "tf.Add"(%input_proj, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
  %perm = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> } : () -> tensor<2xi32>
  %recurrent_t = "tf.Transpose"(%arg4, %perm) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32>
  %recurrent_diag = "tf.MatrixDiag"(%recurrent_t) : (tensor<4x10xf32>) -> tensor<4x10x10xf32>
  %concat_axis = "tf.Const"() { value = dense<0> : tensor<i64> } : () -> tensor<i64>
  %recurrent_full = "tf.ConcatV2"(%recurrent_diag, %concat_axis) : (tensor<4x10x10xf32>, tensor<i64>) -> tensor<40x10xf32>
  %recurrent_proj = "tf.BatchMatMulV2"(%biased, %recurrent_full) {adj_x = false, adj_y = false} : (tensor<8x8x40xf32>, tensor<40x10xf32>) -> tensor<8x8x10xf32>
  %plus_arg1 = "tf.Add"(%recurrent_proj, %arg1) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32>
  %plus_arg2 = "tf.Add"(%recurrent_proj, %arg2) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32>
  %state_sum = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32>
  %unit = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  func.return %state_sum, %plus_arg2, %state_sum, %state_sum, %unit : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
}

// CHECK:       func @inference_standard_indy_lstm_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x4xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "indy_lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
// CHECK:           [[VAL_6:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
// CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32>
// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<1> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<4x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>)
// CHECK-DAG:       [[VAL_20:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK:           [[VAL_21:%.*]] = "tf.Reshape"([[VAL_15]]#0, [[VAL_20]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-DAG:       [[VAL_22:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK:           [[VAL_23:%.*]] = "tf.Reshape"([[VAL_15]]#1, [[VAL_22]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-DAG:       [[VAL_24:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK:           [[VAL_25:%.*]] = "tf.Reshape"([[VAL_15]]#2, [[VAL_24]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK:           [[VAL_27:%.*]] = "tf.Reshape"([[VAL_15]]#3, [[VAL_26]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_30:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_28]], [[VAL_29]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK:           [[VAL_31:%.*]] = "tfl.no_value"() <{value}> : () -> none
// CHECK:           [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true}> : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
// CHECK-DAG:       [[VAL_33:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32>
// CHECK-DAG:       [[VAL_34:%.*]] = arith.constant dense<0> : tensor<3xi32>
// CHECK-DAG:       [[VAL_35:%.*]] = arith.constant dense<1> : tensor<3xi32>
// CHECK:           [[VAL_36:%.*]] = "tf.StridedSlice"([[VAL_32]], [[VAL_33]], [[VAL_34]], [[VAL_35]]) <{begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_37:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_38:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_39:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
// CHECK:           return [[VAL_36]], [[VAL_32]], [[VAL_37]], [[VAL_38]], [[VAL_39]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
// CHECK:         }

}

// -----

module {
// Function tagged tf.api_implements = "lstm_<uuid>" with
// tf.time_major = false and fully static shapes. The FileCheck assertions
// that follow expect a "tfl.unidirectional_sequence_lstm" op carrying
// time_major = false.
// Note: %plus_arg1 below is computed but never used, as in the original input.
func.func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} {
  %input_proj = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
  %biased = "tf.Add"(%input_proj, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
  %recurrent_proj = "tf.BatchMatMulV2"(%biased, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32>
  %plus_arg1 = "tf.Add"(%recurrent_proj, %arg1) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32>
  %plus_arg2 = "tf.Add"(%recurrent_proj, %arg2) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32>
  %state_sum = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32>
  %unit = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  func.return %state_sum, %plus_arg2, %state_sum, %state_sum, %unit : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
}

// CHECK:       func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} {
// CHECK:           [[VAL_6:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
// CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32>
// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_17:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK:           [[VAL_19:%.*]] = "tfl.no_value"() <{value}> : () -> none
// CHECK:           [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}> : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
// CHECK-DAG:       [[VAL_21:%.*]] = arith.constant dense<[0, -1, 0]> : tensor<3xi32>
// CHECK-DAG:       [[VAL_22:%.*]] = arith.constant dense<0> : tensor<3xi32>
// CHECK-DAG:       [[VAL_23:%.*]] = arith.constant dense<1> : tensor<3xi32>
// CHECK:           [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) <{begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64}> : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_25:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
// CHECK:           return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
// CHECK:         }

}

// -----

module {
// Fixture: function tagged `tf.api_implements = "indy_lstm_..."` with
// `tf.go_backwards = false` and `tf.time_major = false`. Here %arg4 is
// tensor<10x4xf32>.
//
// The expectations that follow this function contain none of the ops written
// below. They show the body replaced by one "tfl.unidirectional_sequence_lstm"
// with diagonal_recurrent_tensors = true and time_major = false, where:
//   * %arg3 is transposed and split into four tensor<10x8xf32> pieces;
//   * %arg4 is transposed to tensor<4x10xf32>, split into four
//     tensor<1x10xf32> rows, and each row reshaped to tensor<10xf32>;
//   * %arg5 is split into four tensor<10xf32> pieces;
//   * %arg0, %arg1 and %arg2 are passed to the fused op unchanged.
// The first result is expected to be a "tf.StridedSlice" of the fused op's
// output (shrink_axis_mask = 2), and the last three results zero-valued
// "tf.Const" ops.
func.func @inference_standard_indy_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x4xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "indy_lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} {
  // Stand-in body: only needs to be well-typed so the file parses.
  %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
  %1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
  // Expand the 10x4 argument into a 40x10 operand:
  // transpose, MatrixDiag, then ConcatV2 with axis constant 0.
  %2 = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> } : () -> tensor<2xi32>
  %3 = "tf.Transpose"(%arg4, %2) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32>
  %4 = "tf.MatrixDiag"(%3) : (tensor<4x10xf32>) -> tensor<4x10x10xf32>
  %5 = "tf.Const"() { value = dense<0> : tensor<i64> } : () -> tensor<i64>
  %6 = "tf.ConcatV2"(%4, %5) : (tensor<4x10x10xf32>, tensor<i64>) -> tensor<40x10xf32>
  %7 = "tf.BatchMatMulV2"(%1, %6) {adj_x = false, adj_y = false} : (tensor<8x8x40xf32>, tensor<40x10xf32>) -> tensor<8x8x10xf32>
  // %8 has no users in this function.
  %8 = "tf.Add"(%7, %arg1) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32>
  %9 = "tf.Add"(%7, %arg2) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32>
  %10 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32>
  %11 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  // Five results; %10 is returned in three of the positions.
  func.return %10, %9, %10, %10, %11 : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
}

// CHECK:       func @inference_standard_indy_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x4xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "indy_lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} {
// CHECK:           [[VAL_6:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
// CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32>
// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<1> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<4x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>)
// CHECK-DAG:       [[VAL_20:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK:           [[VAL_21:%.*]] = "tf.Reshape"([[VAL_15]]#0, [[VAL_20]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-DAG:       [[VAL_22:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK:           [[VAL_23:%.*]] = "tf.Reshape"([[VAL_15]]#1, [[VAL_22]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-DAG:       [[VAL_24:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK:           [[VAL_25:%.*]] = "tf.Reshape"([[VAL_15]]#2, [[VAL_24]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK:           [[VAL_27:%.*]] = "tf.Reshape"([[VAL_15]]#3, [[VAL_26]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_30:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_28]], [[VAL_29]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK:           [[VAL_31:%.*]] = "tfl.no_value"() <{value}> : () -> none
// CHECK:           [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}> : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
// CHECK-DAG:       [[VAL_33:%.*]] = arith.constant dense<[0, -1, 0]> : tensor<3xi32>
// CHECK-DAG:       [[VAL_34:%.*]] = arith.constant dense<0> : tensor<3xi32>
// CHECK-DAG:       [[VAL_35:%.*]] = arith.constant dense<1> : tensor<3xi32>
// CHECK:           [[VAL_36:%.*]] = "tf.StridedSlice"([[VAL_32]], [[VAL_33]], [[VAL_34]], [[VAL_35]]) <{begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64}> : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_37:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_38:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_39:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
// CHECK:           return [[VAL_36]], [[VAL_32]], [[VAL_37]], [[VAL_38]], [[VAL_39]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
// CHECK:         }

}

// -----

module {
// Fixture: function tagged `tf.api_implements = "lstm_..."` with
// `tf.go_backwards = true` and `tf.time_major = true`. The leading dimension
// of %arg0 and of the second result is dynamic (`?`).
//
// The expectations that follow this function contain none of the ops written
// below. They show the body replaced by:
//   * a "tf.ReverseV2" of %arg0 along axis 0;
//   * transposes and four-way splits of %arg3 and %arg4, and a four-way split
//     of %arg5;
//   * one "tfl.unidirectional_sequence_lstm" (time_major = true,
//     diagonal_recurrent_tensors = false) fed by the reversed input, the
//     split pieces, and %arg1/%arg2 unchanged;
//   * a "tf.StridedSlice" of the fused output with begin [-1, 0, 0] and
//     shrink_axis_mask = 1, used as the first result.
// The last three results are expected to be zero-valued "tf.Const" ops.
func.func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor<?x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
  // Stand-in body: only needs to be well-typed so the file parses.
  %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
  %1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
  %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
  // %3 has no users in this function.
  %3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<8x10xf32>) -> tensor<?x8x10xf32>
  %4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<8x10xf32>) -> tensor<?x8x10xf32>
  %5 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32>
  %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  // Five results; %5 is returned in three of the positions.
  func.return %5, %4, %5, %5, %6 : tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
}

// CHECK:       func @inference_standard_lstm_time_major_go_backwards([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
// CHECK:           [[VAL_6:%.*]] = arith.constant dense<0> : tensor<1xi32>
// CHECK:           [[VAL_7:%.*]] = "tf.ReverseV2"([[VAL_0]], [[VAL_6]]) : (tensor<?x8x8xf32>, tensor<1xi32>) -> tensor<?x8x8xf32>
// CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
// CHECK:           [[VAL_10:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32>
// CHECK-DAG:       [[VAL_12:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK-DAG:       [[VAL_15:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
// CHECK-DAG:       [[VAL_18:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_19:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK:           [[VAL_21:%.*]] = "tfl.no_value"() <{value}> : () -> none
// CHECK:           [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true}> : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
// CHECK-DAG:       [[VAL_23:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32>
// CHECK-DAG:       [[VAL_24:%.*]] = arith.constant dense<0> : tensor<3xi32>
// CHECK-DAG:       [[VAL_25:%.*]] = arith.constant dense<1> : tensor<3xi32>
// CHECK:           [[VAL_26:%.*]] = "tf.StridedSlice"([[VAL_22]], [[VAL_23]], [[VAL_24]], [[VAL_25]]) <{begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<?x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
// CHECK:           return [[VAL_26]], [[VAL_22]], [[VAL_27]], [[VAL_28]], [[VAL_29]] : tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
// CHECK:         }

}

// -----

module {
// Fixture: function tagged `tf.api_implements = "indy_lstm_..."` with
// `tf.go_backwards = true` and `tf.time_major = true`. Here %arg4 is
// tensor<10x4xf32> and %arg0 is fully static (8x8x8).
//
// The expectations that follow this function contain none of the ops written
// below. They show the body replaced by:
//   * a "tf.ReverseV2" of %arg0 along axis 0;
//   * a transpose and four-way split of %arg3;
//   * a transpose of %arg4 to tensor<4x10xf32>, a split into four
//     tensor<1x10xf32> rows, and a reshape of each row to tensor<10xf32>;
//   * a four-way split of %arg5;
//   * one "tfl.unidirectional_sequence_lstm" with
//     diagonal_recurrent_tensors = true and time_major = true, fed by the
//     reversed input, the pieces above, and %arg1/%arg2 unchanged;
//   * a "tf.StridedSlice" of the fused output with begin [-1, 0, 0] and
//     shrink_axis_mask = 1, used as the first result.
// The last three results are expected to be zero-valued "tf.Const" ops.
func.func @inference_standard_indy_lstm_time_major_go_backwards(%arg0: tensor<8x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x4xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "indy_lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
  // Stand-in body: only needs to be well-typed so the file parses.
  %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
  %1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
  // Expand the 10x4 argument into a 40x10 operand:
  // transpose, MatrixDiag, then ConcatV2 with axis constant 0.
  %2 = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> } : () -> tensor<2xi32>
  %3 = "tf.Transpose"(%arg4, %2) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32>
  %4 = "tf.MatrixDiag"(%3) : (tensor<4x10xf32>) -> tensor<4x10x10xf32>
  %5 = "tf.Const"() { value = dense<0> : tensor<i64> } : () -> tensor<i64>
  %6 = "tf.ConcatV2"(%4, %5) : (tensor<4x10x10xf32>, tensor<i64>) -> tensor<40x10xf32>
  %7 = "tf.BatchMatMulV2"(%1, %6) {adj_x = false, adj_y = false} : (tensor<8x8x40xf32>, tensor<40x10xf32>) -> tensor<8x8x10xf32>
  // %8 has no users in this function.
  %8 = "tf.Add"(%7, %arg1) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32>
  %9 = "tf.Add"(%7, %arg2) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32>
  %10 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32>
  %11 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  // Five results; %10 is returned in three of the positions.
  func.return %10, %9, %10, %10, %11 : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
}

// CHECK:       func @inference_standard_indy_lstm_time_major_go_backwards([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x4xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "indy_lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} {
// CHECK:           [[VAL_40:%.*]] = arith.constant dense<0> : tensor<1xi32>
// CHECK:           [[VAL_41:%.*]] = "tf.ReverseV2"([[VAL_0]], [[VAL_40]]) : (tensor<8x8x8xf32>, tensor<1xi32>) -> tensor<8x8x8xf32>
// CHECK:           [[VAL_6:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
// CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32>
// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<1> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<4x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>)
// CHECK-DAG:       [[VAL_20:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK:           [[VAL_21:%.*]] = "tf.Reshape"([[VAL_15]]#0, [[VAL_20]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-DAG:       [[VAL_22:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK:           [[VAL_23:%.*]] = "tf.Reshape"([[VAL_15]]#1, [[VAL_22]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-DAG:       [[VAL_24:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK:           [[VAL_25:%.*]] = "tf.Reshape"([[VAL_15]]#2, [[VAL_24]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK:           [[VAL_27:%.*]] = "tf.Reshape"([[VAL_15]]#3, [[VAL_26]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_30:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_28]], [[VAL_29]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK:           [[VAL_31:%.*]] = "tfl.no_value"() <{value}> : () -> none
// CHECK:           [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_41]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true}> : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
// CHECK-DAG:       [[VAL_33:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32>
// CHECK-DAG:       [[VAL_34:%.*]] = arith.constant dense<0> : tensor<3xi32>
// CHECK-DAG:       [[VAL_35:%.*]] = arith.constant dense<1> : tensor<3xi32>
// CHECK:           [[VAL_36:%.*]] = "tf.StridedSlice"([[VAL_32]], [[VAL_33]], [[VAL_34]], [[VAL_35]]) <{begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_37:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_38:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_39:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
// CHECK:           return [[VAL_36]], [[VAL_32]], [[VAL_37]], [[VAL_38]], [[VAL_39]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
// CHECK:         }

}

// -----

module {
// Fixture: function tagged `tf.api_implements = "lstm_..."` with
// `tf.go_backwards = true` and `tf.time_major = false`.
//
// The expectations that follow this function contain none of the ops written
// below. They show the body replaced by:
//   * a "tf.ReverseV2" of %arg0 along axis 1;
//   * transposes and four-way splits of %arg3 and %arg4, and a four-way split
//     of %arg5;
//   * one "tfl.unidirectional_sequence_lstm" (time_major = false,
//     diagonal_recurrent_tensors = false) fed by the reversed input, the
//     split pieces, and %arg1/%arg2 unchanged;
//   * a "tf.StridedSlice" of the fused output with begin [0, -1, 0] and
//     shrink_axis_mask = 2, used as the first result.
// The last three results are expected to be zero-valued "tf.Const" ops.
func.func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
  // Stand-in body: only needs to be well-typed so the file parses.
  %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
  %1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
  %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32>
  // %3 has no users in this function.
  %3 = "tf.Add"(%2, %arg1) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32>
  %4 = "tf.Add"(%2, %arg2) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32>
  %5 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32>
  %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  // Five results; %5 is returned in three of the positions.
  func.return %5, %4, %5, %5, %6 : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
}

// CHECK:       func @inference_standard_lstm_non_time_major_go_backwards([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
// CHECK:           [[VAL_6:%.*]] = arith.constant dense<1> : tensor<1xi32>
// CHECK:           [[VAL_7:%.*]] = "tf.ReverseV2"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<1xi32>) -> tensor<8x8x8xf32>
// CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_8]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
// CHECK:           [[VAL_10:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_11:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_10]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32>
// CHECK-DAG:       [[VAL_12:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_14:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_12]], [[VAL_13]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK-DAG:       [[VAL_15:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_17:%.*]]:4 = "tf.SplitV"([[VAL_11]], [[VAL_15]], [[VAL_16]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
// CHECK-DAG:       [[VAL_18:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_19:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK:           [[VAL_21:%.*]] = "tfl.no_value"() <{value}> : () -> none
// CHECK:           [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}> : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
// CHECK-DAG:       [[VAL_23:%.*]] = arith.constant dense<[0, -1, 0]> : tensor<3xi32>
// CHECK-DAG:       [[VAL_24:%.*]] = arith.constant dense<0> : tensor<3xi32>
// CHECK-DAG:       [[VAL_25:%.*]] = arith.constant dense<1> : tensor<3xi32>
// CHECK:           [[VAL_26:%.*]] = "tf.StridedSlice"([[VAL_22]], [[VAL_23]], [[VAL_24]], [[VAL_25]]) <{begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64}> : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
// CHECK:           return [[VAL_26]], [[VAL_22]], [[VAL_27]], [[VAL_28]], [[VAL_29]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
// CHECK:         }

}

// -----

module {
// Test input: composite function whose tf.api_implements value starts with
// "indy_lstm_", with tf.go_backwards = true and tf.time_major = false.
// The FileCheck expectations that follow this function require:
//   * %arg0 to be reversed by tf.ReverseV2 (axis constant dense<1>) before use,
//   * %arg3 and %arg4 to be transposed and split four ways, with each 1x10
//     slice of the %arg4 split reshaped to tensor<10xf32>,
//   * a tfl.unidirectional_sequence_lstm with diagonal_recurrent_tensors = true
//     and time_major = false,
//   * the return to use a strided slice of the fused result, the fused result
//     itself, and zero-valued tf.Const ops instead of the values returned here.
// NOTE(review): the op order and types below are test input for the
// expectations that follow; change them together.
func.func @inference_standard_indy_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x4xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "indy_lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
  %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32>
  %1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32>
  // %arg4 (10x4) is transposed to 4x10 and expanded with tf.MatrixDiag into
  // four 10x10 diagonal matrices, then concatenated into a 40x10 tensor.
  %2 = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> } : () -> tensor<2xi32>
  %3 = "tf.Transpose"(%arg4, %2) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32>
  %4 = "tf.MatrixDiag"(%3) : (tensor<4x10xf32>) -> tensor<4x10x10xf32>
  %5 = "tf.Const"() { value = dense<0> : tensor<i64> } : () -> tensor<i64>
  %6 = "tf.ConcatV2"(%4, %5) : (tensor<4x10x10xf32>, tensor<i64>) -> tensor<40x10xf32>
  %7 = "tf.BatchMatMulV2"(%1, %6) {adj_x = false, adj_y = false} : (tensor<8x8x40xf32>, tensor<40x10xf32>) -> tensor<8x8x10xf32>
  // %8 is computed but not returned.
  %8 = "tf.Add"(%7, %arg1) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32>
  %9 = "tf.Add"(%7, %arg2) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32>
  %10 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32>
  %11 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  func.return %10, %9, %10, %10, %11 : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
}

// CHECK:       func @inference_standard_indy_lstm_non_time_major_go_backwards([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x4xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "indy_lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} {
// CHECK:           [[VAL_40:%.*]] = arith.constant dense<1> : tensor<1xi32>
// CHECK:           [[VAL_41:%.*]] = "tf.ReverseV2"([[VAL_0]], [[VAL_40]]) : (tensor<8x8x8xf32>, tensor<1xi32>) -> tensor<8x8x8xf32>
// CHECK:           [[VAL_6:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
// CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x4xf32>, tensor<2xi32>) -> tensor<4x10xf32>
// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<1> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<4x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>)
// CHECK-DAG:       [[VAL_20:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK:           [[VAL_21:%.*]] = "tf.Reshape"([[VAL_15]]#0, [[VAL_20]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-DAG:       [[VAL_22:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK:           [[VAL_23:%.*]] = "tf.Reshape"([[VAL_15]]#1, [[VAL_22]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-DAG:       [[VAL_24:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK:           [[VAL_25:%.*]] = "tf.Reshape"([[VAL_15]]#2, [[VAL_24]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK:           [[VAL_27:%.*]] = "tf.Reshape"([[VAL_15]]#3, [[VAL_26]]) : (tensor<1x10xf32>, tensor<1xi32>) -> tensor<10xf32>
// CHECK-DAG:       [[VAL_28:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_29:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_30:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_28]], [[VAL_29]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK:           [[VAL_31:%.*]] = "tfl.no_value"() <{value}> : () -> none
// CHECK:           [[VAL_32:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_41]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_21]], [[VAL_23]], [[VAL_25]], [[VAL_27]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_30]]#0, [[VAL_30]]#1, [[VAL_30]]#2, [[VAL_30]]#3, [[VAL_31]], [[VAL_31]], [[VAL_1]], [[VAL_2]], [[VAL_31]], [[VAL_31]], [[VAL_31]], [[VAL_31]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = true, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}> : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32>
// CHECK-DAG:       [[VAL_33:%.*]] = arith.constant dense<[0, -1, 0]> : tensor<3xi32>
// CHECK-DAG:       [[VAL_34:%.*]] = arith.constant dense<0> : tensor<3xi32>
// CHECK-DAG:       [[VAL_35:%.*]] = arith.constant dense<1> : tensor<3xi32>
// CHECK:           [[VAL_36:%.*]] = "tf.StridedSlice"([[VAL_32]], [[VAL_33]], [[VAL_34]], [[VAL_35]]) <{begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64}> : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_37:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_38:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_39:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
// CHECK:           return [[VAL_36]], [[VAL_32]], [[VAL_37]], [[VAL_38]], [[VAL_39]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
// CHECK:         }

}

// -----

module {
// Caller for the fusable time-major case. It invokes
// @inference_standard_lstm_time_major_can_fuse through tf.PartitionedCall and
// consumes only result #1 of that call (in the tf.Add below); results #0, #2,
// #3 and #4 have no uses. The function itself returns nothing.
func.func @inference_can_fuse(%arg0: tensor<?x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) {
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = f32, value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %1:5 = "tf.PartitionedCall"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], _output_shapes = ["tfshape$dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$"], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\002\02J\008\01", device = "", executor_type = "", f = @inference_standard_lstm_time_major_can_fuse} : (tensor<?x8x8xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<8x40xf32>, tensor<10x40xf32>, tensor<40xf32>) -> (tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>)
  // Only the second call result (#1) is used.
  %2 = "tf.Add"(%0, %1#1) : (tensor<f32>, tensor<?x8x10xf32>) -> tensor<?x8x10xf32>
  func.return
}

// Callee whose tf.api_implements value starts with "lstm_", with
// tf.time_major = true and tf.go_backwards = false. %arg0 has a dynamic
// leading dimension (?x8x8) while %arg1 and %arg2 are statically shaped (8x10).
// The FileCheck expectations that follow this function require the body to
// become a tfl.unidirectional_sequence_lstm with time_major = true and
// diagonal_recurrent_tensors = false, fed by transposed and four-way split
// %arg3/%arg4 and a four-way split %arg5, with the first result produced by a
// tf.StridedSlice of the fused output.
func.func @inference_standard_lstm_time_major_can_fuse(%arg0: tensor<?x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
  %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
  %1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
  %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
  // %3 is computed but not returned.
  %3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<8x10xf32>) -> tensor<?x8x10xf32>
  %4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<8x10xf32>) -> tensor<?x8x10xf32>
  %5 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32>
  %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  func.return %5, %4, %5, %5, %6 : tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
}

// CHECK:       func @inference_standard_lstm_time_major_can_fuse([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
// CHECK:           [[VAL_6:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
// CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32>
// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_17:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK:           [[VAL_19:%.*]] = "tfl.no_value"() <{value}> : () -> none
// CHECK:           [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true}> : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
// CHECK-DAG:       [[VAL_21:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32>
// CHECK-DAG:       [[VAL_22:%.*]] = arith.constant dense<0> : tensor<3xi32>
// CHECK-DAG:       [[VAL_23:%.*]] = arith.constant dense<1> : tensor<3xi32>
// CHECK:           [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) <{begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<?x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_25:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
// CHECK:           return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
// CHECK:         }

}

// -----

module {
// Caller for the fusable "last output" case. It invokes
// @inference_standard_lstm_time_major_can_fuse_last_output through
// tf.PartitionedCall and consumes only result #0 of that call (in the tf.Add
// below); results #1 through #4 have no uses. The function returns nothing.
func.func @inference_can_fuse_last_output(%arg0: tensor<?x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) {
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = f32, value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %1:5 = "tf.PartitionedCall"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], _output_shapes = ["tfshape$dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$"], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\002\02J\008\01", device = "", executor_type = "", f = @inference_standard_lstm_time_major_can_fuse_last_output} : (tensor<?x8x8xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<8x40xf32>, tensor<10x40xf32>, tensor<40xf32>) -> (tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>)
  // Only the first call result (#0) is used.
  %2 = "tf.Add"(%0, %1#0) : (tensor<f32>, tensor<8x10xf32>) -> tensor<8x10xf32>
  func.return
}

// Callee whose tf.api_implements value starts with "lstm_", with
// tf.time_major = true and tf.go_backwards = false; its only caller in this
// module uses result #0. The FileCheck expectations that follow this function
// require the body to become a tfl.unidirectional_sequence_lstm with
// time_major = true, with the first returned value produced by a
// tf.StridedSlice of the fused output.
func.func @inference_standard_lstm_time_major_can_fuse_last_output(%arg0: tensor<?x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
  %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
  %1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
  %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
  // %3 is computed but not returned.
  %3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<8x10xf32>) -> tensor<?x8x10xf32>
  %4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<8x10xf32>) -> tensor<?x8x10xf32>
  %5 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32>
  %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  // %7 repeats the same tf.Add as %5 but is a distinct SSA value; it is the one
  // returned as result #0, while %5 supplies results #2 and #3.
  %7 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32>
  func.return %7, %4, %5, %5, %6 : tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
}

// CHECK:       func @inference_standard_lstm_time_major_can_fuse_last_output([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
// CHECK:           [[VAL_6:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32>
// CHECK:           [[VAL_8:%.*]] = arith.constant dense<[1, 0]> : tensor<2xi32>
// CHECK:           [[VAL_9:%.*]] = "tf.Transpose"([[VAL_4]], [[VAL_8]]) : (tensor<10x40xf32>, tensor<2xi32>) -> tensor<40x10xf32>
// CHECK-DAG:       [[VAL_10:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_11:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_12:%.*]]:4 = "tf.SplitV"([[VAL_7]], [[VAL_10]], [[VAL_11]]) : (tensor<40x8xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>)
// CHECK-DAG:       [[VAL_13:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_14:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_15:%.*]]:4 = "tf.SplitV"([[VAL_9]], [[VAL_13]], [[VAL_14]]) : (tensor<40x10xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>)
// CHECK-DAG:       [[VAL_16:%.*]] = "tf.Const"() <{value = dense<10> : tensor<4xi32>}> : () -> tensor<4xi32>
// CHECK-DAG:       [[VAL_17:%.*]] = "tf.Const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK:           [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>)
// CHECK:           [[VAL_19:%.*]] = "tfl.no_value"() <{value}> : () -> none
// CHECK:           [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) <{cell_clip = 1.000000e+01 : f32, diagonal_recurrent_tensors = false, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true}> : (tensor<?x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<?x8x10xf32>
// CHECK-DAG:       [[VAL_21:%.*]] = arith.constant dense<[-1, 0, 0]> : tensor<3xi32>
// CHECK-DAG:       [[VAL_22:%.*]] = arith.constant dense<0> : tensor<3xi32>
// CHECK-DAG:       [[VAL_23:%.*]] = arith.constant dense<1> : tensor<3xi32>
// CHECK:           [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) <{begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<?x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_25:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_26:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<8x10xf32>
// CHECK-DAG:       [[VAL_27:%.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<f32>
// CHECK:           return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
// CHECK:         }

}

// -----

module {
// Negative case: function whose tf.api_implements value starts with "lstm_"
// that takes a seventh argument, %arg6 : tensor<?x8xi1>, which no op in the
// body reads. The FileCheck expectations that follow this function match the
// original tf ops unchanged and in the same order, i.e. no
// tfl.unidirectional_sequence_lstm is expected here.
// NOTE(review): presumably the extra seventh (mask, per the function name)
// argument is what prevents fusion; confirm against the pass implementation.
func.func @inference_standard_lstm_with_mask(%arg0: tensor<?x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>, %arg6: tensor<?x8xi1>) -> (tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false", "tfshape$dim { size: -1 } dim { size: 8 }"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
  %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
  %1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
  %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
  %3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<8x10xf32>) -> tensor<?x8x10xf32>
  %4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<8x10xf32>) -> tensor<?x8x10xf32>
  %5 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32>
  %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  func.return %5, %4, %5, %5, %6 : tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
}

// CHECK:       func @inference_standard_lstm_with_mask([[ARG_0:%.*]]: tensor<?x8x8xf32>, [[ARG_1:%.*]]: tensor<8x10xf32>, [[ARG_2:%.*]]: tensor<8x10xf32>, [[ARG_3:%.*]]: tensor<8x40xf32>, [[ARG_4:%.*]]: tensor<10x40xf32>, [[ARG_5:%.*]]: tensor<40xf32>,  [[ARG_6:%.*]]: tensor<?x8xi1>) -> (tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false", "tfshape$dim { size: -1 } dim { size: 8 }"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
// CHECK:         [[VAL_0:%.*]] = "tf.BatchMatMulV2"([[ARG_0]], [[ARG_3]]) <{adj_x = false, adj_y = false}> : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
// CHECK:         [[VAL_1:%.*]] = "tf.Add"([[VAL_0]], [[ARG_5]]) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
// CHECK:         [[VAL_2:%.*]] = "tf.BatchMatMulV2"([[VAL_1]], [[ARG_4]]) <{adj_x = false, adj_y = true}> : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
// CHECK:         [[VAL_3:%.*]] = "tf.Add"([[VAL_2]], [[ARG_1]]) : (tensor<?x8x10xf32>, tensor<8x10xf32>) -> tensor<?x8x10xf32>
// CHECK:         [[VAL_4:%.*]] = "tf.Add"([[VAL_2]], [[ARG_2]]) : (tensor<?x8x10xf32>, tensor<8x10xf32>) -> tensor<?x8x10xf32>
// CHECK:         [[VAL_5:%.*]] = "tf.Add"([[ARG_1]], [[ARG_2]]) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32>
// CHECK:         [[VAL_6:%.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32} : () -> tensor<f32>
// CHECK:         return [[VAL_5]], [[VAL_4]], [[VAL_5]], [[VAL_5]], [[VAL_6]] : tensor<8x10xf32>, tensor<?x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor<f32>
// CHECK:       }

}

// -----

module {
// Caller for the non-fusable case. It invokes
// @inference_standard_lstm_time_major_cannot_fuse through tf.PartitionedCall
// and consumes result #2 of that call (in the tf.Add below). This differs from
// the fusable callers in this file, which consume only result #0 or #1.
// NOTE(review): presumably the use of result #2 is what makes the callee
// ineligible for fusion; confirm against the pass implementation.
func.func @inference_cannot_fuse(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) {
  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = f32, value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %1:5 = "tf.PartitionedCall"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], _output_shapes = ["tfshape$dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$"], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\002\02J\008\01", device = "", executor_type = "", f = @inference_standard_lstm_time_major_cannot_fuse} : (tensor<?x8x8xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<8x40xf32>, tensor<10x40xf32>, tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>)
  // The third call result (#2) is used here.
  %2 = "tf.Add"(%0, %1#2) : (tensor<f32>, tensor<?x10xf32>) -> tensor<?x10xf32>
  func.return
}

// Negative case: callee whose tf.api_implements value starts with "lstm_",
// with tf.time_major = true; %arg1 and %arg2 are tensor<?x10xf32>. The
// FileCheck expectations that follow this function match the original tf ops
// unchanged and in the same order, i.e. no tfl.unidirectional_sequence_lstm is
// expected here. No diagnostic directive is attached to this function, and the
// RUN line passes -verify-diagnostics.
// NOTE(review): which condition rejects the fusion (the caller's use of result
// #2 or the dynamic leading dimension of %arg1/%arg2) is not visible in this
// file; confirm against the pass implementation.
func.func @inference_standard_lstm_time_major_cannot_fuse(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
  %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
  %1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
  %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
  %3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
  %4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
  %5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
  %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  func.return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}

// CHECK:        func @inference_standard_lstm_time_major_cannot_fuse([[VAL_0:%.*]]: tensor<?x8x8xf32>, [[VAL_1:%.*]]: tensor<?x10xf32>, [[VAL_2:%.*]]: tensor<?x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
// CHECK:           [[VAL_6:%.*]] = "tf.BatchMatMulV2"([[VAL_0]], [[VAL_3]]) <{adj_x = false, adj_y = false}> : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
// CHECK:           [[VAL_7:%.*]] = "tf.Add"([[VAL_6]], [[VAL_5]]) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
// CHECK:           [[VAL_8:%.*]] = "tf.BatchMatMulV2"([[VAL_7]], [[VAL_4]]) <{adj_x = false, adj_y = true}> : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
// CHECK:           [[VAL_9:%.*]] = "tf.Add"([[VAL_8]], [[VAL_1]]) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
// CHECK:           [[VAL_10:%.*]] = "tf.Add"([[VAL_8]], [[VAL_2]]) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
// CHECK:           [[VAL_11:%.*]] = "tf.Add"([[VAL_1]], [[VAL_2]]) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
// CHECK:           [[VAL_12:%.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32} : () -> tensor<f32>
// CHECK:           return [[VAL_11]], [[VAL_10]], [[VAL_11]], [[VAL_11]], [[VAL_12]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
// CHECK:         }
}

// -----

module {
// Negative LSTM fusion case. The leading dimension of %arg0, %arg1 and %arg2 is
// dynamic (`?`); the annotation on the next line records the diagnostic the
// pass has to attach to this function. The FileCheck lines after the function
// list the same seven ops in the same order, i.e. the body must stay unfused.
// expected-warning @+1 {{we cannot fuse this lstm func because the batch size is not fixed, please consider setting fixed batch size}}
func.func @dynamic_shape_non_fuse_standard_lstm(%arg0: tensor<?x8x8xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<?x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
  // Small stand-in body: two matmuls and a few adds, just enough to produce
  // values of every declared result type.
  %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
  %1 = "tf.Add"(%0, %arg5) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
  %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
  // %3 is not returned; the FileCheck lines still expect it in the output.
  %3 = "tf.Add"(%2, %arg1) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
  %4 = "tf.Add"(%2, %arg2) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
  %5 = "tf.Add"(%arg1, %arg2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
  %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  // %5 is returned three times, alongside %4 and the scalar constant %6.
  func.return %5, %4, %5, %5, %6 : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
}

// CHECK: func @dynamic_shape_non_fuse_standard_lstm(%[[VAL_0:.*]]: tensor<?x8x8xf32>, %[[VAL_1:.*]]: tensor<?x10xf32>, %[[VAL_2:.*]]: tensor<?x10xf32>, %[[VAL_3:.*]]: tensor<8x40xf32>, %[[VAL_4:.*]]: tensor<10x40xf32>, %[[VAL_5:.*]]: tensor<40xf32>) -> (tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} {
// CHECK:         %[[VAL_6:.*]] = "tf.BatchMatMulV2"(%[[VAL_0]], %[[VAL_3]]) <{adj_x = false, adj_y = false}> : (tensor<?x8x8xf32>, tensor<8x40xf32>) -> tensor<?x8x40xf32>
// CHECK:         %[[VAL_7:.*]] = "tf.Add"(%[[VAL_6]], %[[VAL_5]]) : (tensor<?x8x40xf32>, tensor<40xf32>) -> tensor<?x8x40xf32>
// CHECK:         %[[VAL_8:.*]] = "tf.BatchMatMulV2"(%[[VAL_7]], %[[VAL_4]]) <{adj_x = false, adj_y = true}> : (tensor<?x8x40xf32>, tensor<10x40xf32>) -> tensor<?x8x10xf32>
// CHECK:         %[[VAL_9:.*]] = "tf.Add"(%[[VAL_8]], %[[VAL_1]]) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
// CHECK:         %[[VAL_10:.*]] = "tf.Add"(%[[VAL_8]], %[[VAL_2]]) : (tensor<?x8x10xf32>, tensor<?x10xf32>) -> tensor<?x8x10xf32>
// CHECK:         %[[VAL_11:.*]] = "tf.Add"(%[[VAL_1]], %[[VAL_2]]) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
// CHECK:         %[[VAL_12:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<f32>}> {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32} : () -> tensor<f32>
// CHECK:         return %[[VAL_11]], %[[VAL_10]], %[[VAL_11]], %[[VAL_11]], %[[VAL_12]] : tensor<?x10xf32>, tensor<?x8x10xf32>, tensor<?x10xf32>, tensor<?x10xf32>, tensor<f32>
// CHECK:       }
}

// -----

module {
// Positive case for tf._implements = "non_max_suppression_padded_v2" with
// unbatched inputs (tensor<100x4xf32> and tensor<100xf32>). The FileCheck lines
// after the function expect the body to be replaced by a single
// "tfl.non_max_suppression_v4" op fed with %arg0..%arg4, whose two results are
// returned directly. The constants below are placeholders that only give the
// function a well-formed body with the two declared result types.
func.func @nms_padded(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> (tensor<1x10xi32>, tensor<i32>) attributes  {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"} {
  %0 = "tf.Const"() {value = dense<1> : tensor<1x10xi32>} : () -> tensor<1x10xi32>
  // The value attribute type must agree with the declared result type
  // (tensor<i32>); a rank-1 attribute on a rank-0 result is a type mismatch on
  // the constant itself and would be reported before the pass under test runs.
  %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  func.return %0, %1 : tensor<1x10xi32>, tensor<i32>
}

// CHECK:       func @nms_padded(%[[VAL_119:.*]]: tensor<100x4xf32>, %[[VAL_120:.*]]: tensor<100xf32>, %[[VAL_121:.*]]: tensor<i32>, %[[VAL_122:.*]]: tensor<f32>, %[[VAL_123:.*]]: tensor<f32>, %[[VAL_124:.*]]: tensor<i1>, %[[VAL_125:.*]]: tensor<i1>, %[[VAL_126:.*]]: tensor<i1>, %[[VAL_127:.*]]: tensor<i32>) -> (tensor<1x10xi32>, tensor<i32>) attributes  {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"} {
// CHECK:         %[[VAL_128:.*]], %[[VAL_129:.*]] = "tfl.non_max_suppression_v4"(%[[VAL_119]], %[[VAL_120]], %[[VAL_121]], %[[VAL_122]], %[[VAL_123]]) : (tensor<100x4xf32>, tensor<100xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<1x10xi32>, tensor<i32>)
// CHECK:         return %[[VAL_128]], %[[VAL_129]] : tensor<1x10xi32>, tensor<i32>
// CHECK:       }
}

// -----

module {
// Negative cases for tf._implements = "non_max_suppression_padded_v2". Each one
// is a body-less private declaration, and the annotation directly above it
// gives the diagnostic that must be reported on that declaration.

// Declares no results at all (`-> ()`).
// expected-warning @+1 {{Invalid number of results from non_max_suppression_padded_v2}}
func.func private @nms_padded_invalid_num_results(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> () attributes  {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}

// Declares only four arguments (%arg0..%arg3).
// expected-warning @+1 {{Invalid number of arguments to non_max_suppression_padded_v2}}
func.func private @nms_padded_invalid_num_args(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>) -> (tensor<1x10xi32>, tensor<i32>) attributes  {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}

// %arg0 and %arg1 carry an extra leading dimension of size 2
// (tensor<2x100x4xf32>, tensor<2x100xf32>).
// expected-warning @+1 {{TFLite does not support batched input for non_max_suppression_padded}}
func.func private @nms_padded_with_batches(%arg0: tensor<2x100x4xf32>, %arg1: tensor<2x100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> (tensor<2x10xi32>, tensor<i32>) attributes  {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
}

// -----

module {
// CHECK-LABEL: func private @some_func
// CHECK-LABEL: func @func_with_call
// @some_func is a body-less private declaration tagged with an LSTM-style
// tf.api_implements key, and @func_with_call is a caller of it. The label
// checks above only require that both functions are still present in the
// output, and no diagnostic is annotated for this module.
func.func private @some_func(%arg0: tensor<100xf32>) -> tensor<100xf32> attributes {tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c"}
func.func @func_with_call(%arg0: tensor<100xf32>) -> tensor<100xf32> {
  // Forward the argument to the declaration and return its result unchanged.
  %0 = func.call @some_func(%arg0) : (tensor<100xf32>) -> tensor<100xf32>
  func.return %0 : tensor<100xf32>
  }
}

// -----

module {
// Positive case for TFLite_Detection_PostProcess. tf._implements is a
// #tf_type.func attribute whose dictionary carries all ten option values. The
// FileCheck lines after the function expect:
//   * tf._implements reduced to the plain string "TFLite_Detection_PostProcess";
//   * the body replaced by one four-result "tfl.custom" op that takes the three
//     arguments, with the options serialized into custom_option.
func.func @tflite_custom_nms(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes  {tf._implements = #tf_type.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} {
  // Four zero-valued scalar constants serve as a placeholder body that
  // supplies the four declared results.
  %0 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
  %1 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
  %2 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
  %3 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
  func.return %0, %1, %2, %3 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
}

// CHECK-LABEL: func @tflite_custom_nms(
// CHECK-SAME:                          %[[VAL_0:.*]]: tensor<1x100x4xf32>,
// CHECK-SAME:                          %[[VAL_1:.*]]: tensor<1x100x91xf32>,
// CHECK-SAME:                          %[[VAL_2:.*]]: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes {tf._implements = "TFLite_Detection_PostProcess", tf._reference = "mlir"} {
// CHECK:         %[[VAL_3:.*]]:4 = "tfl.custom"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{custom_code = "TFLite_Detection_PostProcess", custom_option = #tfl<const_bytes : "0x6D61785F646574656374696F6E73006D61785F636C61737365735F7065725F646574656374696F6E006E756D5F636C6173736573006E6D735F73636F72655F7468726573686F6C64006E6D735F696F755F7468726573686F6C6400795F7363616C6500785F7363616C6500685F7363616C6500775F7363616C65007573655F726567756C61725F6E6D73000A217E8E465B681720313A00000C000000010000000A0000000000803F010000000A0000009A99193F0000003F5B0000000000000000000040000020410000A0400E06060E0E06060E0E0E322601">}> : (tensor<1x100x4xf32>, tensor<1x100x91xf32>, tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>)
// CHECK:         return %[[VAL_3]]#0, %[[VAL_3]]#1, %[[VAL_3]]#2, %[[VAL_3]]#3 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
// CHECK:       }
}

// -----

module {
// Negative signature cases for TFLite_Detection_PostProcess. Both are body-less
// private declarations with the full ten-entry option dictionary; only the
// function signature deviates.

// Declares three results.
// expected-warning @+1 {{Invalid number of results from TFLite_Detection_PostProcess}}
func.func private @tflite_custom_nms_invalid_results(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) attributes  {tf._implements = #tf_type.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}

// Declares two arguments.
// expected-warning @+1 {{Invalid number of arguments to TFLite_Detection_PostProcess}}
func.func private @tflite_custom_nms_invalid_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes  {tf._implements = #tf_type.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, max_classes_per_detection = 1 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"}

// Negative option case for TFLite_Detection_PostProcess: the signature has
// three arguments and four results, but the option dictionary omits
// max_classes_per_detection. The annotation below gives the diagnostic that
// must be reported on the function.
// expected-warning @+1 {{max_classes_per_detection attribute is not set or not an integer}}
func.func private @tflite_custom_nms_missing_func_args(%arg0: tensor<1x100x4xf32>, %arg1: tensor<1x100x91xf32>, %arg2: tensor<100x4xf32>) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) attributes  {tf._implements = #tf_type.func<@"TFLite_Detection_PostProcess", {max_detections = 10 : i64, num_classes = 91 : i64, nms_score_threshold = 0.5 : f32, nms_iou_threshold = 0.6 : f32, y_scale = 5.0 : f32, x_scale = 10.0 : f32, h_scale = 1.0 : f32, w_scale = 2.0 : f32, use_regular_nms = 0 : i1}>, tf._reference = "mlir"} {
  // Placeholder body: four zero-valued scalar constants, one per result.
  %0 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
  %1 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
  %2 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
  %3 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
  func.return %0, %1, %2, %3 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
}
}

// -----

module {
// Positive case for "addons:MaxUnpooling2D" with padding = "SAME",
// pool_size = [2, 2] and strides = [2, 2]. The FileCheck lines after the
// function expect tf._implements rewritten to "MaxUnpooling2D" and the whole
// body replaced by one "tfl.custom" op (custom_code "MaxUnpooling2D") taking
// %arg0 and %arg1, followed directly by the return.
func.func @max_unpooling_2d(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf_type.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2], strides = [2, 2]}>} {
  // Shape and index constants used by the TF-level body below.
  %0 = "tf.Const"() {value = dense<[4, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
  %1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32>
  %2 = "tf.Const"() {value = dense<0> : tensor<1x1x2x1xi32>} : () -> tensor<1x1x2x1xi32>
  %3 = "tf.Const"() {value = dense<[1, 2, 4, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
  %4 = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
  %5 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %6 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
  // Derive two index components from %arg1: (%arg1 floordiv 1) floormod 4 and
  // %arg1 floordiv 4.
  %7 = "tf.FloorDiv"(%arg1, %5) {device = ""} : (tensor<1x1x2x1xi32>, tensor<i32>) -> tensor<1x1x2x1xi32>
  %8 = "tf.FloorMod"(%7, %4) {device = ""} : (tensor<1x1x2x1xi32>, tensor<i32>) -> tensor<1x1x2x1xi32>
  %9 = "tf.FloorDiv"(%arg1, %4) {device = ""} : (tensor<1x1x2x1xi32>, tensor<i32>) -> tensor<1x1x2x1xi32>
  // Pack the components with zero-filled first and last entries, then reshape
  // and transpose into a 2x4 index tensor for the scatter.
  %10 = "tf.Pack"(%2, %9, %8, %2) {axis = 0 : i64, device = ""} : (tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>) -> tensor<4x1x1x2x1xi32>
  %11 = "tf.Reshape"(%10, %0) {device = ""} : (tensor<4x1x1x2x1xi32>, tensor<2xi32>) -> tensor<4x2xi32>
  %12 = "tf.Transpose"(%11, %6) {device = ""} : (tensor<4x2xi32>, tensor<2xi32>) -> tensor<2x4xi32>
  // Flatten %arg0 to two values and scatter them into a [1, 2, 4, 1] tensor.
  %13 = "tf.Reshape"(%arg0, %1) {device = ""} : (tensor<1x1x2x1xf32>, tensor<1xi32>) -> tensor<2xf32>
  %14 = "tf.ScatterNd"(%12, %13, %3) {device = ""} : (tensor<2x4xi32>, tensor<2xf32>, tensor<4xi32>) -> tensor<1x2x4x1xf32>
  %15 = "tf.Identity"(%14) {device = ""} : (tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
  func.return %15 : tensor<1x2x4x1xf32>
}

// CHECK-LABEL: func @max_unpooling_2d(
// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<1x1x2x1xf32>,
// CHECK-SAME:                         %[[VAL_1:.*]]: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = "MaxUnpooling2D"} {
// CHECK-NEXT:    %[[VAL_2:.*]] = "tfl.custom"(%[[VAL_0]], %[[VAL_1]]) <{custom_code = "MaxUnpooling2D", custom_option = #tfl<const_bytes : "0x01000000020000000200000002000000020000000000000000000000000000000000000000000000">}> : (tensor<1x1x2x1xf32>, tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32>
// CHECK-NEXT:    return %[[VAL_2]] : tensor<1x2x4x1xf32>
// CHECK-NEXT:  }
}

// -----

module {
// Negative cases for "addons:MaxUnpooling2D". Each one is a body-less private
// declaration, and the annotation directly above it gives the diagnostic that
// must be reported on that declaration.

// Declares two results.
// expected-warning @+1 {{Invalid number of results from MaxUnpooling2D}}
func.func private @max_unpooling_2d_invalid_results(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> (tensor<1x2x4x1xf32>, tensor<1x2x4x1xi32>) attributes {tf._implements = #tf_type.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2], strides = [2, 2]}>}

// Declares a single argument.
// expected-warning @+1 {{Invalid number of arguments to MaxUnpooling2D}}
func.func private @max_unpooling_2d_invalid_args(%arg0: tensor<1x1x2x1xf32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf_type.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2], strides = [2, 2]}>}

// padding is the string "NO".
// expected-warning @+1 {{Padding for MaxUnpooling2D must be 'SAME' or 'VALID'}}
func.func private @max_unpooling_2d_wrong_padding(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf_type.func<@"addons:MaxUnpooling2D", {padding = "NO", pool_size = [2, 2], strides = [2, 2]}>}

// pool_size has one element.
// expected-warning @+1 {{'pool_size' attribute for MaxUnpooling2D must be set and has size of 2}}
func.func private @max_unpooling_2d_wrong_filter(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf_type.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2], strides = [2, 2]}>}

// strides has three elements.
// expected-warning @+1 {{'strides' attribute for MaxUnpooling2D must be set and has size of 2}}
func.func private @max_unpooling_2d_wrong_strides(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf_type.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2], strides = [2, 2, 2]}>}

// padding is absent from the option dictionary.
// expected-warning @+1 {{'padding' attribute for MaxUnpooling2D is not set or not a string}}
func.func private @max_unpooling_2d_no_padding(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf_type.func<@"addons:MaxUnpooling2D", {pool_size = [2, 2], strides = [2, 2]}>}

// pool_size is absent from the option dictionary.
// expected-warning @+1 {{'pool_size' attribute for MaxUnpooling2D must be set and has size of 2}}
func.func private @max_unpooling_2d_no_filter(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf_type.func<@"addons:MaxUnpooling2D", {padding = "SAME", strides = [2, 2]}>}

// strides is absent from the option dictionary.
// expected-warning @+1 {{'strides' attribute for MaxUnpooling2D must be set and has size of 2}}
func.func private @max_unpooling_2d_no_strides(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf_type.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2]}>}

// pool_size holds two strings instead of integers.
// expected-warning @+1 {{'pool_size' attribute for MaxUnpooling2D does not contain integer values}}
func.func private @max_unpooling_2d_filter_wrong_type(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf_type.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = ["a", "b"], strides = [2, 2]}>}

// strides holds two strings ("2", "2") instead of integers.
  // expected-warning @+1 {{'strides' attribute for MaxUnpooling2D does not contain integer values}}
func.func private @max_unpooling_2d_strides_wrong_type(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xi32>) -> tensor<1x2x4x1xf32> attributes {tf._implements = #tf_type.func<@"addons:MaxUnpooling2D", {padding = "SAME", pool_size = [2, 2], strides = ["2", "2"]}>}
}

// -----

module {
// Public entry point with no tf._implements attribute of its own: it forwards
// both arguments to @__inference_dense_image_warp through tf.PartitionedCall
// and returns the result. No FileCheck line in this module refers to it.
func.func @dense_image_warp(%arg0: tensor<2x4x4x1xf32>, %arg1: tensor<2x4x4x2xf32>) -> tensor<2x4x4x1xf32> {
  %0 = "tf.PartitionedCall"(%arg0, %arg1) {_collective_manager_ids = [], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\002\02J\008\01\82\01\00", executor_type = "", f = @__inference_dense_image_warp} : (tensor<2x4x4x1xf32>, tensor<2x4x4x2xf32>) -> tensor<2x4x4x1xf32>
  func.return %0 : tensor<2x4x4x1xf32>
}

// Composite function tagged tf._implements = "addons:DenseImageWarp". The
// FileCheck lines at the end of this module expect tf._implements rewritten to
// "DenseImageWarp" and the body replaced by one "tfl.custom" op (custom_code
// "DenseImageWarp", empty custom_option "0x") taking %arg0 and %arg1.
func.func private @__inference_dense_image_warp(%arg0: tensor<2x4x4x1xf32>, %arg1: tensor<2x4x4x2xf32>) -> tensor<2x4x4x1xf32> attributes {tf._implements = "addons:DenseImageWarp"} {
  // %0 is a constant 1x4x4x2 grid whose entry at position (i, j) is the pair
  // [i, j].
  %0 = "tf.Const"() {value = dense<[[[[0.000000e+00, 0.000000e+00], [0.000000e+00, 1.000000e+00], [0.000000e+00, 2.000000e+00], [0.000000e+00, 3.000000e+00]], [[1.000000e+00, 0.000000e+00], [1.000000e+00, 1.000000e+00], [1.000000e+00, 2.000000e+00], [1.000000e+00, 3.000000e+00]], [[2.000000e+00, 0.000000e+00], [2.000000e+00, 1.000000e+00], [2.000000e+00, 2.000000e+00], [2.000000e+00, 3.000000e+00]], [[3.000000e+00, 0.000000e+00], [3.000000e+00, 1.000000e+00], [3.000000e+00, 2.000000e+00], [3.000000e+00, 3.000000e+00]]]]> : tensor<1x4x4x2xf32>} : () -> tensor<1x4x4x2xf32>
  %1 = "tf.Const"() {value = dense<[2, 16, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
  %2 = "tf.Const"() {value = dense<[2, 4, 4, 1]> : tensor<4xi32>} : () -> tensor<4xi32>
  // Subtract %arg1 from the grid and flatten the result to [2, 16, 2].
  %3 = "tf.Sub"(%0, %arg1) {device = ""} : (tensor<1x4x4x2xf32>, tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32>
  %4 = "tf.Reshape"(%3, %1) {device = ""} : (tensor<2x4x4x2xf32>, tensor<3xi32>) -> tensor<2x16x2xf32>
  // Hand %arg0 and the flattened coordinates to the interpolation helper, then
  // reshape its [2, 16, 1] result back to [2, 4, 4, 1].
  %5 = "tf.PartitionedCall"(%arg0, %4) {_collective_manager_ids = [], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\002\02J\008\01\82\01\00", executor_type = "", f = @__inference_interpolate_bilinear} : (tensor<2x4x4x1xf32>, tensor<2x16x2xf32>) -> tensor<2x16x1xf32>
  %6 = "tf.Reshape"(%5, %2) {device = ""} : (tensor<2x16x1xf32>, tensor<4xi32>) -> tensor<2x4x4x1xf32>
  %7 = "tf.Identity"(%6) {device = ""} : (tensor<2x4x4x1xf32>) -> tensor<2x4x4x1xf32>
  func.return %7 : tensor<2x4x4x1xf32>
}

// Helper called from @__inference_dense_image_warp through tf.PartitionedCall.
// It has no tf._implements attribute, and no FileCheck line in this module
// refers to it. %arg0 is a 2x4x4x1 tensor of values and %arg1 a 2x16x2 tensor
// of query coordinates; the result holds one value per query ([2, 16, 1]).
func.func private @__inference_interpolate_bilinear(%arg0: tensor<2x4x4x1xf32>, %arg1: tensor<2x16x2xf32>) -> tensor<2x16x1xf32> {
  // Scalar and shape constants shared by the computation below.
  %0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %2 = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32>
  %3 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
  %4 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %5 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
  %6 = "tf.Const"() {value = dense<[[0], [16]]> : tensor<2x1xi32>} : () -> tensor<2x1xi32>
  %7 = "tf.Const"() {value = dense<[32, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
  %8 = "tf.Const"() {value = dense<4> : tensor<i32>} : () -> tensor<i32>
  // Flatten %arg0 to [32, 1] so that values can be gathered by flat index.
  %9 = "tf.Reshape"(%arg0, %7) {device = ""} : (tensor<2x4x4x1xf32>, tensor<2xi32>) -> tensor<32x1xf32>
  // Split the query coordinates along the last axis into two [2, 16] planes.
  %10:2 = "tf.Unpack"(%arg1) {axis = 2 : i64, device = ""} : (tensor<2x16x2xf32>) -> (tensor<2x16xf32>, tensor<2x16xf32>)
  // First coordinate: floor, clamp to [0, 2], cast to i32 (%14).
  %11 = "tf.Floor"(%10#0) {device = ""} : (tensor<2x16xf32>) -> tensor<2x16xf32>
  %12 = "tf.Maximum"(%0, %11) {device = ""} : (tensor<f32>, tensor<2x16xf32>) -> tensor<2x16xf32>
  %13 = "tf.Minimum"(%12, %4) {device = ""} : (tensor<2x16xf32>, tensor<f32>) -> tensor<2x16xf32>
  %14 = "tf.Cast"(%13) {Truncate = false, device = ""} : (tensor<2x16xf32>) -> tensor<2x16xi32>
  // Flat offsets for index %14 + 1 (%17) and index %14 (%19): multiply by 4
  // and add the per-slice offset %6 ([[0], [16]]).
  %15 = "tf.AddV2"(%14, %3) {device = ""} : (tensor<2x16xi32>, tensor<i32>) -> tensor<2x16xi32>
  %16 = "tf.Mul"(%15, %8) {device = ""} : (tensor<2x16xi32>, tensor<i32>) -> tensor<2x16xi32>
  %17 = "tf.AddV2"(%16, %6) {device = ""} : (tensor<2x16xi32>, tensor<2x1xi32>) -> tensor<2x16xi32>
  %18 = "tf.Mul"(%14, %8) {device = ""} : (tensor<2x16xi32>, tensor<i32>) -> tensor<2x16xi32>
  %19 = "tf.AddV2"(%18, %6) {device = ""} : (tensor<2x16xi32>, tensor<2x1xi32>) -> tensor<2x16xi32>
  // Fractional part of the first coordinate, clamped to [0, 1] and expanded to
  // [2, 16, 1] (%23).
  %20 = "tf.Sub"(%10#0, %13) {device = ""} : (tensor<2x16xf32>, tensor<2x16xf32>) -> tensor<2x16xf32>
  %21 = "tf.Maximum"(%0, %20) {device = ""} : (tensor<f32>, tensor<2x16xf32>) -> tensor<2x16xf32>
  %22 = "tf.Minimum"(%21, %1) {device = ""} : (tensor<2x16xf32>, tensor<f32>) -> tensor<2x16xf32>
  %23 = "tf.ExpandDims"(%22, %2) {device = ""} : (tensor<2x16xf32>, tensor<i32>) -> tensor<2x16x1xf32>
  // Second coordinate: floor, clamp to [0, 2], cast to i32 (%27), plus one (%28).
  %24 = "tf.Floor"(%10#1) {device = ""} : (tensor<2x16xf32>) -> tensor<2x16xf32>
  %25 = "tf.Maximum"(%0, %24) {device = ""} : (tensor<f32>, tensor<2x16xf32>) -> tensor<2x16xf32>
  %26 = "tf.Minimum"(%25, %4) {device = ""} : (tensor<2x16xf32>, tensor<f32>) -> tensor<2x16xf32>
  %27 = "tf.Cast"(%26) {Truncate = false, device = ""} : (tensor<2x16xf32>) -> tensor<2x16xi32>
  %28 = "tf.AddV2"(%27, %3) {device = ""} : (tensor<2x16xi32>, tensor<i32>) -> tensor<2x16xi32>
  // Gather the four neighbouring values (%30, %32, %34, %37) and form the two
  // differences along the second coordinate (%35, %38).
  %29 = "tf.AddV2"(%17, %28) {device = ""} : (tensor<2x16xi32>, tensor<2x16xi32>) -> tensor<2x16xi32>
  %30 = "tf.GatherV2"(%9, %29, %5) {batch_dims = 0 : i64, device = ""} : (tensor<32x1xf32>, tensor<2x16xi32>, tensor<i32>) -> tensor<2x16x1xf32>
  %31 = "tf.AddV2"(%19, %28) {device = ""} : (tensor<2x16xi32>, tensor<2x16xi32>) -> tensor<2x16xi32>
  %32 = "tf.GatherV2"(%9, %31, %5) {batch_dims = 0 : i64, device = ""} : (tensor<32x1xf32>, tensor<2x16xi32>, tensor<i32>) -> tensor<2x16x1xf32>
  %33 = "tf.AddV2"(%17, %27) {device = ""} : (tensor<2x16xi32>, tensor<2x16xi32>) -> tensor<2x16xi32>
  %34 = "tf.GatherV2"(%9, %33, %5) {batch_dims = 0 : i64, device = ""} : (tensor<32x1xf32>, tensor<2x16xi32>, tensor<i32>) -> tensor<2x16x1xf32>
  %35 = "tf.Sub"(%30, %34) {device = ""} : (tensor<2x16x1xf32>, tensor<2x16x1xf32>) -> tensor<2x16x1xf32>
  %36 = "tf.AddV2"(%19, %27) {device = ""} : (tensor<2x16xi32>, tensor<2x16xi32>) -> tensor<2x16xi32>
  %37 = "tf.GatherV2"(%9, %36, %5) {batch_dims = 0 : i64, device = ""} : (tensor<32x1xf32>, tensor<2x16xi32>, tensor<i32>) -> tensor<2x16x1xf32>
  %38 = "tf.Sub"(%32, %37) {device = ""} : (tensor<2x16x1xf32>, tensor<2x16x1xf32>) -> tensor<2x16x1xf32>
  // Fractional part of the second coordinate, clamped to [0, 1] and expanded
  // to [2, 16, 1] (%42).
  %39 = "tf.Sub"(%10#1, %26) {device = ""} : (tensor<2x16xf32>, tensor<2x16xf32>) -> tensor<2x16xf32>
  %40 = "tf.Maximum"(%0, %39) {device = ""} : (tensor<f32>, tensor<2x16xf32>) -> tensor<2x16xf32>
  %41 = "tf.Minimum"(%40, %1) {device = ""} : (tensor<2x16xf32>, tensor<f32>) -> tensor<2x16xf32>
  %42 = "tf.ExpandDims"(%41, %2) {device = ""} : (tensor<2x16xf32>, tensor<i32>) -> tensor<2x16x1xf32>
  // Blend along the second coordinate for both index rows (%44, %46), then
  // blend those two along the first coordinate (%49).
  %43 = "tf.Mul"(%42, %38) {device = ""} : (tensor<2x16x1xf32>, tensor<2x16x1xf32>) -> tensor<2x16x1xf32>
  %44 = "tf.AddV2"(%43, %37) {device = ""} : (tensor<2x16x1xf32>, tensor<2x16x1xf32>) -> tensor<2x16x1xf32>
  %45 = "tf.Mul"(%42, %35) {device = ""} : (tensor<2x16x1xf32>, tensor<2x16x1xf32>) -> tensor<2x16x1xf32>
  %46 = "tf.AddV2"(%45, %34) {device = ""} : (tensor<2x16x1xf32>, tensor<2x16x1xf32>) -> tensor<2x16x1xf32>
  %47 = "tf.Sub"(%46, %44) {device = ""} : (tensor<2x16x1xf32>, tensor<2x16x1xf32>) -> tensor<2x16x1xf32>
  %48 = "tf.Mul"(%23, %47) {device = ""} : (tensor<2x16x1xf32>, tensor<2x16x1xf32>) -> tensor<2x16x1xf32>
  %49 = "tf.AddV2"(%48, %44) {device = ""} : (tensor<2x16x1xf32>, tensor<2x16x1xf32>) -> tensor<2x16x1xf32>
  %50 = "tf.Identity"(%49) {device = ""} : (tensor<2x16x1xf32>) -> tensor<2x16x1xf32>
  func.return %50 : tensor<2x16x1xf32>
}

// CHECK-LABEL: func private @__inference_dense_image_warp(
// CHECK-SAME:      %arg0: tensor<2x4x4x1xf32>,
// CHECK-SAME:      %arg1: tensor<2x4x4x2xf32>) -> tensor<2x4x4x1xf32> attributes {tf._implements = "DenseImageWarp"} {
// CHECK-NEXT:    %0 = "tfl.custom"(%arg0, %arg1) <{custom_code = "DenseImageWarp", custom_option = #tfl<const_bytes : "0x">}> : (tensor<2x4x4x1xf32>, tensor<2x4x4x2xf32>) -> tensor<2x4x4x1xf32>
// CHECK-NEXT:    return %0 : tensor<2x4x4x1xf32>
// CHECK-NEXT:  }
}

// -----

module {
// Negative cases for "addons:DenseImageWarp". Each one is a body-less private
// declaration, and the annotation directly above it gives the diagnostic that
// must be reported on that declaration.

// Declares a single argument.
// expected-warning @+1 {{Invalid number of arguments to DenseImageWarp}}
func.func private @dense_image_warp_invalid_inputs(%arg0: tensor<2x4x4x1xf32>) -> tensor<2x4x4x1xf32> attributes {tf._implements = "addons:DenseImageWarp"}

// %arg0 is rank 3.
// expected-warning @+1 {{Image should be a 4D float tensor}}
func.func private @dense_image_warp_invalid_input_shape(%arg0: tensor<2x4x4xf32>, %arg1: tensor<2x4x4x2xf32>) -> tensor<2x4x4x1xf32> attributes {tf._implements = "addons:DenseImageWarp"}

// %arg1 is rank 3.
// expected-warning @+1 {{Flow should be a 4D float tensor}}
func.func private @dense_image_warp_invalid_flow_shape(%arg0: tensor<2x4x4x1xf32>, %arg1: tensor<2x4x4xf32>) -> tensor<2x4x4x1xf32> attributes {tf._implements = "addons:DenseImageWarp"}

// The result is rank 3.
// expected-warning @+1 {{Output should be a 4D float tensor}}
func.func private @dense_image_warp_invalid_output_shape(%arg0: tensor<2x4x4x1xf32>, %arg1: tensor<2x4x4x2xf32>) -> tensor<2x4x4xf32> attributes {tf._implements = "addons:DenseImageWarp"}

// Both arguments and the result are rank 1 with a dynamic dimension; the
// diagnostic that has to be reported is the one about the image.
// expected-warning @+1 {{Image should be a 4D float tensor}}
func.func private @dense_image_warp_dynamic_shape(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> attributes {tf._implements = "addons:DenseImageWarp"}

// %arg0 has element type i32.
// expected-warning @+1 {{Image should be a 4D float tensor}}
func.func private @dense_image_warp_invalid_input_type(%arg0: tensor<2x4x4x1xi32>, %arg1: tensor<2x4x4x2xf32>) -> tensor<2x4x4x1xf32> attributes {tf._implements = "addons:DenseImageWarp"}

// %arg1 has element type i32.
// expected-warning @+1 {{Flow should be a 4D float tensor}}
func.func private @dense_image_warp_invalid_flow_type(%arg0: tensor<2x4x4x1xf32>, %arg1: tensor<2x4x4x2xi32>) -> tensor<2x4x4x1xf32> attributes {tf._implements = "addons:DenseImageWarp"}

// The result has element type i32.
// expected-warning @+1 {{Output should be a 4D float tensor}}
func.func private @dense_image_warp_invalid_output_type(%arg0: tensor<2x4x4x1xf32>, %arg1: tensor<2x4x4x2xf32>) -> tensor<2x4x4x1xi32> attributes {tf._implements = "addons:DenseImageWarp"}
}

// -----

module {
// User-defined composite op: tf._implements is a #tf_type.func attribute for
// @my_composite_op whose dictionary sets tfl_fusable_op = true alongside two
// example options. The FileCheck lines after the function expect the
// tf._implements attribute to be printed unchanged and the body replaced by
// one two-result "tfl.custom" op (custom_code "my_composite_op") taking all
// three arguments.
func.func @my_composite_op_150(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>) attributes {tf._implements = #tf_type.func<@my_composite_op, {example_option = 10 : i64, example_str = "value 1.01", tfl_fusable_op = true}>} {
  // TF-level body: the first result is %arg0 + %arg1, the second is that sum
  // multiplied by %arg2; both results are unranked.
  %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<*xf32>
  %1 = "tf.Identity"(%0) {device = ""} : (tensor<*xf32>) -> tensor<*xf32>
  %2 = "tf.Mul"(%0, %arg2) {device = ""} : (tensor<*xf32>, tensor<4x4xf32>) -> tensor<*xf32>
  %3 = "tf.Identity"(%2) {device = ""} : (tensor<*xf32>) -> tensor<*xf32>
  func.return %1, %3 : tensor<*xf32>, tensor<*xf32>
}

// CHECK-LABEL: func @my_composite_op_150(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>) attributes {tf._implements = #tf_type.func<@my_composite_op, {example_option = 10 : i64, example_str = "value 1.01", tfl_fusable_op = true}>} {
// CHECK-NEXT:  %0:2 = "tfl.custom"(%arg0, %arg1, %arg2) <{custom_code = "my_composite_op", custom_option = #tfl<const_bytes : "0x6578616D706C655F6F7074696F6E006578616D706C655F737472000A76616C756520312E30310002281A0201020A120414042401">}> : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>)
// CHECK-NEXT:  return %0#0, %0#1 : tensor<*xf32>, tensor<*xf32>
// CHECK-NEXT: }

}
